Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 63 additions & 16 deletions library/proc_macro/src/bridge/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use std::marker::PhantomData;
use std::sync::atomic::AtomicU32;

use super::*;
use crate::StandaloneLevel;
use crate::bridge::server::{Dispatcher, DispatcherTrait};
use crate::bridge::standalone::NoRustc;

macro_rules! define_client_handles {
(
Expand Down Expand Up @@ -58,7 +61,7 @@ macro_rules! define_client_handles {
}
}

impl<S> DecodeMut<'_, '_, S> for $oty {
impl<S> Decode<'_, '_, S> for $oty {
fn decode(r: &mut Reader<'_>, s: &mut S) -> Self {
$oty {
handle: handle::Handle::decode(r, s),
Expand All @@ -82,7 +85,7 @@ macro_rules! define_client_handles {
}
}

impl<S> DecodeMut<'_, '_, S> for $ity {
impl<S> Decode<'_, '_, S> for $ity {
fn decode(r: &mut Reader<'_>, s: &mut S) -> Self {
$ity {
handle: handle::Handle::decode(r, s),
Expand Down Expand Up @@ -141,7 +144,10 @@ macro_rules! define_client_side {
api_tags::Method::$name(api_tags::$name::$method).encode(&mut buf, &mut ());
$($arg.encode(&mut buf, &mut ());)*

buf = bridge.dispatch.call(buf);
buf = match &mut bridge.dispatch {
DispatchWay::Closure(f) => f.call(buf),
DispatchWay::Directly(disp) => disp.dispatch(buf),
};

let r = Result::<_, PanicMessage>::decode(&mut &buf[..], &mut ());

Expand All @@ -155,13 +161,18 @@ macro_rules! define_client_side {
}
with_api!(self, self, define_client_side);

enum DispatchWay<'a> {
Closure(closure::Closure<'a, Buffer, Buffer>),
Directly(Dispatcher<NoRustc>),
}

struct Bridge<'a> {
/// Reusable buffer (only `clear`-ed, never shrunk), primarily
/// used for making requests.
cached_buffer: Buffer,

/// Server-side function that the client uses to make requests.
dispatch: closure::Closure<'a, Buffer, Buffer>,
dispatch: DispatchWay<'a>,

/// Provided globals for this macro expansion.
globals: ExpnGlobals<Span>,
Expand All @@ -173,12 +184,33 @@ impl<'a> !Sync for Bridge<'a> {}
#[allow(unsafe_code)]
mod state {
use std::cell::{Cell, RefCell};
use std::marker::PhantomData;
use std::ptr;

use super::Bridge;
use crate::StandaloneLevel;
use crate::bridge::buffer::Buffer;
use crate::bridge::client::{COUNTERS, DispatchWay};
use crate::bridge::server::{Dispatcher, HandleStore, MarkedTypes};
use crate::bridge::{ExpnGlobals, Marked, standalone};

thread_local! {
static BRIDGE_STATE: Cell<*const ()> = const { Cell::new(ptr::null()) };
static STANDALONE: RefCell<Bridge<'static>> = RefCell::new(standalone_bridge());
pub(super) static USE_STANDALONE: Cell<StandaloneLevel> = const { Cell::new(StandaloneLevel::Never) };
}

fn standalone_bridge() -> Bridge<'static> {
let mut store = HandleStore::new(&COUNTERS);
let id = store.Span.alloc(Marked { value: standalone::Span::DUMMY, _marker: PhantomData });
let dummy = super::Span { handle: id };
let dispatcher =
Dispatcher { handle_store: store, server: MarkedTypes(standalone::NoRustc) };
Bridge {
cached_buffer: Buffer::new(),
dispatch: DispatchWay::Directly(dispatcher),
globals: ExpnGlobals { call_site: dummy, def_site: dummy, mixed_site: dummy },
}
}

pub(super) fn set<'bridge, R>(state: &RefCell<Bridge<'bridge>>, f: impl FnOnce() -> R) -> R {
Expand All @@ -199,16 +231,23 @@ mod state {
pub(super) fn with<R>(
f: impl for<'bridge> FnOnce(Option<&RefCell<Bridge<'bridge>>>) -> R,
) -> R {
let state = BRIDGE_STATE.get();
// SAFETY: the only place where the pointer is set is in `set`. It puts
// back the previous value after the inner call has returned, so we know
// that as long as the pointer is not null, it came from a reference to
// a `RefCell<Bridge>` that outlasts the call to this function. Since `f`
// works the same for any lifetime of the bridge, including the actual
// one, we can lie here and say that the lifetime is `'static` without
// anyone noticing.
let bridge = unsafe { state.cast::<RefCell<Bridge<'static>>>().as_ref() };
f(bridge)
let level = USE_STANDALONE.get();
if level == StandaloneLevel::Always
|| (level == StandaloneLevel::FallbackOnly && BRIDGE_STATE.get().is_null())
{
STANDALONE.with(|bridge| f(Some(bridge)))
} else {
let state = BRIDGE_STATE.get();
// SAFETY: the only place where the pointer is set is in `set`. It puts
// back the previous value after the inner call has returned, so we know
// that as long as the pointer is not null, it came from a reference to
// a `RefCell<Bridge>` that outlasts the call to this function. Since `f`
// works the same for any lifetime of the bridge, including the actual
// one, we can lie here and say that the lifetime is `'static` without
// anyone noticing.
let bridge = unsafe { state.cast::<RefCell<Bridge<'static>>>().as_ref() };
f(bridge)
}
}
}

Expand All @@ -228,6 +267,10 @@ pub(crate) fn is_available() -> bool {
state::with(|s| s.is_some())
}

pub(crate) fn enable_standalone(level: StandaloneLevel) {
state::USE_STANDALONE.set(level);
}

/// A client-side RPC entry-point, which may be using a different `proc_macro`
/// from the one used by the server, but can be invoked compatibly.
///
Expand Down Expand Up @@ -276,7 +319,7 @@ fn maybe_install_panic_hook(force_show_panics: bool) {
/// Client-side helper for handling client panics, entering the bridge,
/// deserializing input and serializing output.
// FIXME(eddyb) maybe replace `Bridge::enter` with this?
fn run_client<A: for<'a, 's> DecodeMut<'a, 's, ()>, R: Encode<()>>(
fn run_client<A: for<'a, 's> Decode<'a, 's, ()>, R: Encode<()>>(
config: BridgeConfig<'_>,
f: impl FnOnce(A) -> R,
) -> Buffer {
Expand All @@ -292,7 +335,11 @@ fn run_client<A: for<'a, 's> DecodeMut<'a, 's, ()>, R: Encode<()>>(
let (globals, input) = <(ExpnGlobals<Span>, A)>::decode(reader, &mut ());

// Put the buffer we used for input back in the `Bridge` for requests.
let state = RefCell::new(Bridge { cached_buffer: buf.take(), dispatch, globals });
let state = RefCell::new(Bridge {
cached_buffer: buf.take(),
dispatch: DispatchWay::Closure(dispatch),
globals,
});

let output = state::set(&state, || f(input));

Expand Down
5 changes: 3 additions & 2 deletions library/proc_macro/src/bridge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,13 @@ mod rpc;
mod selfless_reify;
#[forbid(unsafe_code)]
pub mod server;
pub(crate) mod standalone;
#[allow(unsafe_code)]
mod symbol;

use buffer::Buffer;
pub use rpc::PanicMessage;
use rpc::{DecodeMut, Encode, Reader, Writer};
use rpc::{Decode, Encode, Reader, Writer};

/// Configuration for establishing an active connection between a server and a
/// client. The server creates the bridge config (`run_server` in `server.rs`),
Expand All @@ -168,7 +169,7 @@ impl !Sync for BridgeConfig<'_> {}
#[forbid(unsafe_code)]
#[allow(non_camel_case_types)]
mod api_tags {
use super::rpc::{DecodeMut, Encode, Reader, Writer};
use super::rpc::{Decode, Encode, Reader, Writer};

macro_rules! declare_tags {
($($name:ident {
Expand Down
45 changes: 20 additions & 25 deletions library/proc_macro/src/bridge/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub(super) trait Encode<S>: Sized {

pub(super) type Reader<'a> = &'a [u8];

pub(super) trait DecodeMut<'a, 's, S>: Sized {
pub(super) trait Decode<'a, 's, S>: Sized {
fn decode(r: &mut Reader<'a>, s: &'s mut S) -> Self;
}

Expand All @@ -24,7 +24,7 @@ macro_rules! rpc_encode_decode {
}
}

impl<S> DecodeMut<'_, '_, S> for $ty {
impl<S> Decode<'_, '_, S> for $ty {
fn decode(r: &mut Reader<'_>, _: &mut S) -> Self {
const N: usize = size_of::<$ty>();

Expand All @@ -43,12 +43,12 @@ macro_rules! rpc_encode_decode {
}
}

impl<'a, S, $($($T: for<'s> DecodeMut<'a, 's, S>),+)?> DecodeMut<'a, '_, S>
impl<'a, S, $($($T: for<'s> Decode<'a, 's, S>),+)?> Decode<'a, '_, S>
for $name $(<$($T),+>)?
{
fn decode(r: &mut Reader<'a>, s: &mut S) -> Self {
$name {
$($field: DecodeMut::decode(r, s)),*
$($field: Decode::decode(r, s)),*
}
}
}
Expand All @@ -58,23 +58,18 @@ macro_rules! rpc_encode_decode {
fn encode(self, w: &mut Writer, s: &mut S) {
// HACK(eddyb): `Tag` enum duplicated between the
// two impls as there's no other place to stash it.
#[allow(non_upper_case_globals)]
mod tag {
#[repr(u8)] enum Tag { $($variant),* }

$(pub(crate) const $variant: u8 = Tag::$variant as u8;)*
}
#[repr(u8)] enum Tag { $($variant),* }

match self {
$($name::$variant $(($field))* => {
tag::$variant.encode(w, s);
(Tag::$variant as u8).encode(w, s);
$($field.encode(w, s);)*
})*
}
}
}

impl<'a, S, $($($T: for<'s> DecodeMut<'a, 's, S>),+)?> DecodeMut<'a, '_, S>
impl<'a, S, $($($T: for<'s> Decode<'a, 's, S>),+)?> Decode<'a, '_, S>
for $name $(<$($T),+>)?
{
fn decode(r: &mut Reader<'a>, s: &mut S) -> Self {
Expand All @@ -89,7 +84,7 @@ macro_rules! rpc_encode_decode {

match u8::decode(r, s) {
$(tag::$variant => {
$(let $field = DecodeMut::decode(r, s);)*
$(let $field = Decode::decode(r, s);)*
$name::$variant $(($field))*
})*
_ => unreachable!(),
Expand All @@ -103,7 +98,7 @@ impl<S> Encode<S> for () {
fn encode(self, _: &mut Writer, _: &mut S) {}
}

impl<S> DecodeMut<'_, '_, S> for () {
impl<S> Decode<'_, '_, S> for () {
fn decode(_: &mut Reader<'_>, _: &mut S) -> Self {}
}

Expand All @@ -113,7 +108,7 @@ impl<S> Encode<S> for u8 {
}
}

impl<S> DecodeMut<'_, '_, S> for u8 {
impl<S> Decode<'_, '_, S> for u8 {
fn decode(r: &mut Reader<'_>, _: &mut S) -> Self {
let x = r[0];
*r = &r[1..];
Expand All @@ -130,7 +125,7 @@ impl<S> Encode<S> for bool {
}
}

impl<S> DecodeMut<'_, '_, S> for bool {
impl<S> Decode<'_, '_, S> for bool {
fn decode(r: &mut Reader<'_>, s: &mut S) -> Self {
match u8::decode(r, s) {
0 => false,
Expand All @@ -146,7 +141,7 @@ impl<S> Encode<S> for char {
}
}

impl<S> DecodeMut<'_, '_, S> for char {
impl<S> Decode<'_, '_, S> for char {
fn decode(r: &mut Reader<'_>, s: &mut S) -> Self {
char::from_u32(u32::decode(r, s)).unwrap()
}
Expand All @@ -158,7 +153,7 @@ impl<S> Encode<S> for NonZero<u32> {
}
}

impl<S> DecodeMut<'_, '_, S> for NonZero<u32> {
impl<S> Decode<'_, '_, S> for NonZero<u32> {
fn decode(r: &mut Reader<'_>, s: &mut S) -> Self {
Self::new(u32::decode(r, s)).unwrap()
}
Expand All @@ -171,11 +166,11 @@ impl<S, A: Encode<S>, B: Encode<S>> Encode<S> for (A, B) {
}
}

impl<'a, S, A: for<'s> DecodeMut<'a, 's, S>, B: for<'s> DecodeMut<'a, 's, S>> DecodeMut<'a, '_, S>
impl<'a, S, A: for<'s> Decode<'a, 's, S>, B: for<'s> Decode<'a, 's, S>> Decode<'a, '_, S>
for (A, B)
{
fn decode(r: &mut Reader<'a>, s: &mut S) -> Self {
(DecodeMut::decode(r, s), DecodeMut::decode(r, s))
(Decode::decode(r, s), Decode::decode(r, s))
}
}

Expand All @@ -186,7 +181,7 @@ impl<S> Encode<S> for &[u8] {
}
}

impl<'a, S> DecodeMut<'a, '_, S> for &'a [u8] {
impl<'a, S> Decode<'a, '_, S> for &'a [u8] {
fn decode(r: &mut Reader<'a>, s: &mut S) -> Self {
let len = usize::decode(r, s);
let xs = &r[..len];
Expand All @@ -201,7 +196,7 @@ impl<S> Encode<S> for &str {
}
}

impl<'a, S> DecodeMut<'a, '_, S> for &'a str {
impl<'a, S> Decode<'a, '_, S> for &'a str {
fn decode(r: &mut Reader<'a>, s: &mut S) -> Self {
str::from_utf8(<&[u8]>::decode(r, s)).unwrap()
}
Expand All @@ -213,7 +208,7 @@ impl<S> Encode<S> for String {
}
}

impl<S> DecodeMut<'_, '_, S> for String {
impl<S> Decode<'_, '_, S> for String {
fn decode(r: &mut Reader<'_>, s: &mut S) -> Self {
<&str>::decode(r, s).to_string()
}
Expand All @@ -228,7 +223,7 @@ impl<S, T: Encode<S>> Encode<S> for Vec<T> {
}
}

impl<'a, S, T: for<'s> DecodeMut<'a, 's, S>> DecodeMut<'a, '_, S> for Vec<T> {
impl<'a, S, T: for<'s> Decode<'a, 's, S>> Decode<'a, '_, S> for Vec<T> {
fn decode(r: &mut Reader<'a>, s: &mut S) -> Self {
let len = usize::decode(r, s);
let mut vec = Vec::with_capacity(len);
Expand Down Expand Up @@ -288,7 +283,7 @@ impl<S> Encode<S> for PanicMessage {
}
}

impl<S> DecodeMut<'_, '_, S> for PanicMessage {
impl<S> Decode<'_, '_, S> for PanicMessage {
fn decode(r: &mut Reader<'_>, s: &mut S) -> Self {
match Option::<String>::decode(r, s) {
Some(s) => PanicMessage::String(s),
Expand Down
2 changes: 1 addition & 1 deletion library/proc_macro/src/bridge/selfless_reify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ macro_rules! define_reify_functions {
>(f: F) -> $(extern $abi)? fn($($arg_ty),*) -> $ret_ty {
// FIXME(eddyb) describe the `F` type (e.g. via `type_name::<F>`) once panic
// formatting becomes possible in `const fn`.
assert!(size_of::<F>() == 0, "selfless_reify: closure must be zero-sized");
const { assert!(size_of::<F>() == 0, "selfless_reify: closure must be zero-sized"); }

$(extern $abi)? fn wrapper<
$($($param,)*)?
Expand Down
Loading
Loading