diff --git a/wasm-msg/src/lib.rs b/wasm-msg/src/lib.rs index 375fb12..28eb73d 100644 --- a/wasm-msg/src/lib.rs +++ b/wasm-msg/src/lib.rs @@ -7,7 +7,6 @@ pub use paste::paste; pub mod memory; pub mod message; pub mod sync; -pub mod tls; pub use sync::WasmResult; diff --git a/wasm-msg/src/tls.rs b/wasm-msg/src/tls.rs deleted file mode 100644 index 63f4268..0000000 --- a/wasm-msg/src/tls.rs +++ /dev/null @@ -1,45 +0,0 @@ -use core::cell::UnsafeCell; - -const MAX_CONCURRENT_THREADS: usize = 16; - -#[link(wasm_import_module = "wasm_msg")] -extern "C" { - fn wasm_msg_current_thread_id() -> usize; -} - -pub struct ThreadLocalStorage { - storage: [UnsafeCell>; MAX_CONCURRENT_THREADS], -} -// SAFETY: ThreadLocalStorage is designed to be thread-safe when accessed through -// the Context's thread_id. Each thread gets its own slot based on thread_id, -// so there's no data race between threads. The Option wrapper ensures we can -// initialize it lazily. -unsafe impl Sync for ThreadLocalStorage {} - -impl ThreadLocalStorage { - pub fn get(&self) -> &mut T { - self.get_or_init(T::default) - } -} - -impl ThreadLocalStorage { - #[allow(clippy::mut_from_ref)] - fn get_slot(&self, slot: usize) -> &mut Option { - if slot >= MAX_CONCURRENT_THREADS { - panic!("Thread ID out of bounds"); - } - unsafe { &mut *self.storage[slot].get() } - } - - pub const fn new() -> Self { - Self { - storage: [const { UnsafeCell::new(None) }; MAX_CONCURRENT_THREADS], - } - } - - #[allow(clippy::mut_from_ref)] - pub fn get_or_init(&self, init: impl FnOnce() -> T) -> &mut T { - let thread_id = unsafe { wasm_msg_current_thread_id() }; - self.get_slot(thread_id).get_or_insert_with(init) - } -} diff --git a/wasm/go-host/resolver_api.go b/wasm/go-host/resolver_api.go index b0bfadd..7ade18b 100644 --- a/wasm/go-host/resolver_api.go +++ b/wasm/go-host/resolver_api.go @@ -35,20 +35,6 @@ type ResolverApi struct { func NewResolverApi(ctx context.Context, runtime wazero.Runtime, wasmBytes []byte) *ResolverApi { // Register host functions as a separate module _, err := runtime.NewHostModuleBuilder("wasm_msg"). - NewFunctionBuilder(). - WithFunc(func(ctx context.Context, mod api.Module, ptr uint32) uint32 { - // log_resolve: ignore payload, return Void - response := &messages.Response{Result: &messages.Response_Data{Data: mustMarshal(&messages.Void{})}} - return transferResponse(mod, response) - }). - Export("wasm_msg_host_log_resolve"). - NewFunctionBuilder(). - WithFunc(func(ctx context.Context, mod api.Module, ptr uint32) uint32 { - // log_assign: ignore payload, return Void - response := &messages.Response{Result: &messages.Response_Data{Data: mustMarshal(&messages.Void{})}} - return transferResponse(mod, response) - }). - Export("wasm_msg_host_log_assign"). NewFunctionBuilder(). WithFunc(func(ctx context.Context, mod api.Module, ptr uint32) uint32 { // Return current timestamp @@ -65,11 +51,6 @@ func NewResolverApi(ctx context.Context, runtime wazero.Runtime, wasmBytes []byt return transferResponse(mod, response) }). Export("wasm_msg_host_current_time"). - NewFunctionBuilder(). - WithFunc(func(ctx context.Context, mod api.Module) uint32 { - return 0 - }). - Export("wasm_msg_current_thread_id"). Instantiate(ctx) if err != nil { panic(fmt.Sprintf("Failed to register host functions: %v", err)) @@ -92,9 +73,8 @@ func NewResolverApi(ctx context.Context, runtime wazero.Runtime, wasmBytes []byt wasmMsgFree := instance.ExportedFunction("wasm_msg_free") wasmMsgGuestSetResolverState := instance.ExportedFunction("wasm_msg_guest_set_resolver_state") wasmMsgGuestResolve := instance.ExportedFunction("wasm_msg_guest_resolve") - wasmMsgGuestResolveSimple := instance.ExportedFunction("wasm_msg_guest_resolve_simple") - if wasmMsgAlloc == nil || wasmMsgFree == nil || wasmMsgGuestSetResolverState == nil || wasmMsgGuestResolve == nil || wasmMsgGuestResolveSimple == nil { + if wasmMsgAlloc == nil || wasmMsgFree == nil || wasmMsgGuestSetResolverState == nil || wasmMsgGuestResolve == nil { panic("Required WASM exports not found") } @@ -106,7 +86,6 @@ func NewResolverApi(ctx context.Context, runtime wazero.Runtime, wasmBytes []byt wasmMsgFree: wasmMsgFree, wasmMsgGuestSetResolverState: wasmMsgGuestSetResolverState, wasmMsgGuestResolve: wasmMsgGuestResolve, - wasmMsgGuestResolveSimple: wasmMsgGuestResolveSimple, } } @@ -169,26 +148,6 @@ func (r *ResolverApi) Resolve(request *resolver.ResolveFlagsRequest) (*resolver. return response, nil } -func (r *ResolverApi) ResolveSimple(request *messages.ResolveSimpleRequest) (*resolver.ResolvedFlag, error) { - ctx := context.Background() - reqPtr := r.transferRequest(request) - - results, err := r.wasmMsgGuestResolveSimple.Call(ctx, uint64(reqPtr)) - if err != nil { - return nil, fmt.Errorf("failed to call wasm_msg_guest_resolve_simple: %w", err) - } - - respPtr := uint32(results[0]) - - response := &resolver.ResolvedFlag{} - err = r.consumeResponse(respPtr, response) - if err != nil { - return nil, err - } - - return response, nil -} - // transferRequest transfers a protobuf message to WASM memory func (r *ResolverApi) transferRequest(message proto.Message) uint32 { data := mustMarshal(message) diff --git a/wasm/java-host/src/main/java/com/spotify/confidence/wasmresolvepoc/ResolverApi.java b/wasm/java-host/src/main/java/com/spotify/confidence/wasmresolvepoc/ResolverApi.java index c67a359..1fb8824 100644 --- a/wasm/java-host/src/main/java/com/spotify/confidence/wasmresolvepoc/ResolverApi.java +++ b/wasm/java-host/src/main/java/com/spotify/confidence/wasmresolvepoc/ResolverApi.java @@ -9,18 +9,14 @@ import com.dylibso.chicory.wasm.WasmModule; import com.dylibso.chicory.wasm.types.FunctionType; import com.dylibso.chicory.wasm.types.ValType; -import com.google.protobuf.ByteString; import com.google.protobuf.GeneratedMessage; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Timestamp; import com.spotify.confidence.flags.resolver.v1.ResolveFlagsRequest; import com.spotify.confidence.flags.resolver.v1.ResolveFlagsResponse; -import com.spotify.confidence.flags.resolver.v1.ResolvedFlag; import rust_guest.Messages.SetResolverStateRequest; import rust_guest.Messages; -import rust_guest.Messages.ResolveSimpleRequest; -import rust_guest.Types; import java.util.List; import java.util.function.Function; @@ -37,16 +33,12 @@ public class ResolverApi { // api private final ExportFunction wasmMsgGuestSetResolverState; private final ExportFunction wasmMsgGuestResolve; - private final ExportFunction wasmMsgGuestResolveSimple; public ResolverApi(WasmModule module) { instance = Instance.builder(module) .withImportValues(ImportValues.builder() .addFunction(createImportFunction("current_time", Messages.Void::parseFrom, this::currentTime)) - .addFunction(createImportFunction("log_resolve", Types.LogResolveRequest::parseFrom, this::logResolve)) - .addFunction(createImportFunction("log_assign", Types.LogAssignRequest::parseFrom, this::logAssign)) - .addFunction(new ImportFunction("wasm_msg", "wasm_msg_current_thread_id", FunctionType.of(List.of(), List.of(ValType.I32)), (instance1, args) -> new long[]{0})) .build()) .withMachineFactory(MachineFactoryCompiler::compile) .build(); @@ -54,16 +46,6 @@ public ResolverApi(WasmModule module) { wasmMsgFree = instance.export("wasm_msg_free"); wasmMsgGuestSetResolverState = instance.export("wasm_msg_guest_set_resolver_state"); wasmMsgGuestResolve = instance.export("wasm_msg_guest_resolve"); - wasmMsgGuestResolveSimple = instance.export("wasm_msg_guest_resolve_simple"); - } - - private GeneratedMessage logAssign(Types.LogAssignRequest logAssignRequest) { - System.out.println("logAssign"); - return Messages.Void.getDefaultInstance(); - } - - private GeneratedMessage logResolve(Types.LogResolveRequest logResolveRequest) { - return Messages.Void.getDefaultInstance(); } private Timestamp currentTime(Messages.Void unused) { @@ -85,12 +67,6 @@ public ResolveFlagsResponse resolve(ResolveFlagsRequest request) { return consumeResponse(respPtr, ResolveFlagsResponse::parseFrom); } - public ResolvedFlag resolve_simple(ResolveSimpleRequest request) { - int reqPtr = transferRequest(request); - int respPtr = (int) wasmMsgGuestResolveSimple.apply(reqPtr)[0]; - return consumeResponse(respPtr, ResolvedFlag::parseFrom); - } - private T consumeResponse(int addr, ParserFn codec) { try { Messages.Response response = Messages.Response.parseFrom(consume(addr)); diff --git a/wasm/python-host/resolver_api.py b/wasm/python-host/resolver_api.py index 7ab7f92..252d2ba 100644 --- a/wasm/python-host/resolver_api.py +++ b/wasm/python-host/resolver_api.py @@ -38,18 +38,6 @@ def __init__(self, wasm_bytes: bytes): def _register_host_functions(self): """Register host functions that can be called from WASM""" - def log_resolve(ptr: int) -> int: - # Ignore payload; return Void - response = messages_pb2.Response() - response.data = messages_pb2.Void().SerializeToString() - return self._transfer_response(response) - - def log_assign(ptr: int) -> int: - # Ignore payload; return Void - response = messages_pb2.Response() - response.data = messages_pb2.Void().SerializeToString() - return self._transfer_response(response) - def current_time(ptr: int) -> int: """Host function to return current timestamp""" try: @@ -75,20 +63,11 @@ def current_time(ptr: int) -> int: # Create function type: takes one i32 parameter, returns one i32 func_type = FuncType([ValType.i32()], [ValType.i32()]) host_func_time = Func(self.store, func_type, current_time) - host_func_log_resolve = Func(self.store, func_type, log_resolve) - host_func_log_assign = Func(self.store, func_type, log_assign) linker = Linker(self.store.engine) # Define the import with module and name linker.define(self.store, "wasm_msg", "wasm_msg_host_current_time", host_func_time) - linker.define(self.store, "wasm_msg", "wasm_msg_host_log_resolve", host_func_log_resolve) - linker.define(self.store, "wasm_msg", "wasm_msg_host_log_assign", host_func_log_assign) - - # Optional: current thread id function - def current_thread_id() -> int: - return 0 - linker.define(self.store, "wasm_msg", "wasm_msg_current_thread_id", Func(self.store, FuncType([], [ValType.i32()]), current_thread_id)) # Instantiate the module with imports self.instance = linker.instantiate(self.store, self.module) diff --git a/wasm/rust-guest/src/lib.rs b/wasm/rust-guest/src/lib.rs index dfdf897..0e4c2a9 100644 --- a/wasm/rust-guest/src/lib.rs +++ b/wasm/rust-guest/src/lib.rs @@ -14,7 +14,6 @@ use rand::distr::Alphanumeric; use rand::distr::SampleString; use rand::rngs::SmallRng; use rand::SeedableRng; -use wasm_msg; use wasm_msg::wasm_msg_guest; use wasm_msg::wasm_msg_host; use wasm_msg::WasmResult; @@ -28,23 +27,26 @@ use confidence_resolver::{ proto::{ confidence::flags::admin::v1::ResolverState as ResolverStatePb, confidence::flags::resolver::v1::{ - ResolveFlagsRequest, ResolveFlagsResponse, ResolveWithStickyResponse, ResolvedFlag, Sdk, + ResolveFlagsRequest, ResolveFlagsResponse, ResolveWithStickyResponse, Sdk, }, google::{Struct, Timestamp}, }, Client, FlagToApply, Host, ResolveReason, ResolvedValue, ResolverState, }; -use proto::{ResolveSimpleRequest, Void}; +use proto::Void; -impl Into - for confidence_resolver::proto::confidence::flags::resolver::v1::events::FallthroughAssignment +impl + From + for proto::FallthroughAssignment { - fn into(self) -> proto::FallthroughAssignment { + fn from( + val: confidence_resolver::proto::confidence::flags::resolver::v1::events::FallthroughAssignment, + ) -> Self { proto::FallthroughAssignment { - rule: self.rule, - assignment_id: self.assignment_id, - targeting_key: self.targeting_key, - targeting_key_selector: self.targeting_key_selector, + rule: val.rule, + assignment_id: val.assignment_id, + targeting_key: val.targeting_key, + targeting_key_selector: val.targeting_key_selector, } } } @@ -63,16 +65,17 @@ thread_local! { }); } -impl<'a> Into for &ResolvedValue<'a> { - fn into(self) -> proto::ResolvedValue { +impl<'a> From<&ResolvedValue<'a>> for proto::ResolvedValue { + fn from(val: &ResolvedValue<'a>) -> Self { proto::ResolvedValue { flag: Some(proto::Flag { - name: self.flag.name.clone(), + name: val.flag.name.clone(), }), - reason: convert_reason(self.reason.clone()), - assignment_match: match &self.assignment_match { - None => None, - Some(am) => Some(proto::AssignmentMatch { + reason: convert_reason(val.reason), + assignment_match: val + .assignment_match + .as_ref() + .map(|am| proto::AssignmentMatch { matched_rule: Some(proto::MatchedRule { name: am.rule.clone().name, }), @@ -84,8 +87,7 @@ impl<'a> Into for &ResolvedValue<'a> { }), assignment_id: am.assignment_id.to_string(), }), - }, - fallthrough_rules: self + fallthrough_rules: val .fallthrough_rules .iter() .map(|fr| proto::FallthroughRule { @@ -185,21 +187,14 @@ wasm_msg_guest! { let resolve_request = &request.resolve_request.clone().unwrap(); let evaluation_context = resolve_request.evaluation_context.clone().unwrap(); let resolver = resolver_state.get_resolver::(resolve_request.client_secret.as_str(), evaluation_context, &ENCRYPTION_KEY)?; - resolver.resolve_flags_sticky(&request).into() + resolver.resolve_flags_sticky(&request) } fn resolve(request: ResolveFlagsRequest) -> WasmResult { let resolver_state = get_resolver_state()?; let evaluation_context = request.evaluation_context.as_ref().cloned().unwrap_or_default(); let resolver = resolver_state.get_resolver::(&request.client_secret, evaluation_context, &ENCRYPTION_KEY)?; - resolver.resolve_flags(&request).into() - } - fn resolve_simple(request: ResolveSimpleRequest) -> WasmResult { - let resolver_state = get_resolver_state()?; - let evaluation_context = request.evaluation_context.as_ref().cloned().unwrap_or_default(); - let resolver = resolver_state.get_resolver::(&request.client_secret, evaluation_context, &ENCRYPTION_KEY).unwrap(); - let resolve_result = resolver.resolve_flag_name(&request.name)?; - Ok((&resolve_result.resolved_value).into()) + resolver.resolve_flags(&request) } fn flush_logs(_request:Void) -> WasmResult { let response = LOGGER.checkpoint();