Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion wasm-msg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ pub use paste::paste;
pub mod memory;
pub mod message;
pub mod sync;
pub mod tls;

pub use sync::WasmResult;

Expand Down
45 changes: 0 additions & 45 deletions wasm-msg/src/tls.rs

This file was deleted.

43 changes: 1 addition & 42 deletions wasm/go-host/resolver_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,6 @@ type ResolverApi struct {
func NewResolverApi(ctx context.Context, runtime wazero.Runtime, wasmBytes []byte) *ResolverApi {
// Register host functions as a separate module
_, err := runtime.NewHostModuleBuilder("wasm_msg").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, ptr uint32) uint32 {
// log_resolve: ignore payload, return Void
response := &messages.Response{Result: &messages.Response_Data{Data: mustMarshal(&messages.Void{})}}
return transferResponse(mod, response)
}).
Export("wasm_msg_host_log_resolve").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, ptr uint32) uint32 {
// log_assign: ignore payload, return Void
response := &messages.Response{Result: &messages.Response_Data{Data: mustMarshal(&messages.Void{})}}
return transferResponse(mod, response)
}).
Export("wasm_msg_host_log_assign").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module, ptr uint32) uint32 {
// Return current timestamp
Expand All @@ -65,11 +51,6 @@ func NewResolverApi(ctx context.Context, runtime wazero.Runtime, wasmBytes []byt
return transferResponse(mod, response)
}).
Export("wasm_msg_host_current_time").
NewFunctionBuilder().
WithFunc(func(ctx context.Context, mod api.Module) uint32 {
return 0
}).
Export("wasm_msg_current_thread_id").
Instantiate(ctx)
if err != nil {
panic(fmt.Sprintf("Failed to register host functions: %v", err))
Expand All @@ -92,9 +73,8 @@ func NewResolverApi(ctx context.Context, runtime wazero.Runtime, wasmBytes []byt
wasmMsgFree := instance.ExportedFunction("wasm_msg_free")
wasmMsgGuestSetResolverState := instance.ExportedFunction("wasm_msg_guest_set_resolver_state")
wasmMsgGuestResolve := instance.ExportedFunction("wasm_msg_guest_resolve")
wasmMsgGuestResolveSimple := instance.ExportedFunction("wasm_msg_guest_resolve_simple")

if wasmMsgAlloc == nil || wasmMsgFree == nil || wasmMsgGuestSetResolverState == nil || wasmMsgGuestResolve == nil || wasmMsgGuestResolveSimple == nil {
if wasmMsgAlloc == nil || wasmMsgFree == nil || wasmMsgGuestSetResolverState == nil || wasmMsgGuestResolve == nil {
panic("Required WASM exports not found")
}

Expand All @@ -106,7 +86,6 @@ func NewResolverApi(ctx context.Context, runtime wazero.Runtime, wasmBytes []byt
wasmMsgFree: wasmMsgFree,
wasmMsgGuestSetResolverState: wasmMsgGuestSetResolverState,
wasmMsgGuestResolve: wasmMsgGuestResolve,
wasmMsgGuestResolveSimple: wasmMsgGuestResolveSimple,
}
}

Expand Down Expand Up @@ -169,26 +148,6 @@ func (r *ResolverApi) Resolve(request *resolver.ResolveFlagsRequest) (*resolver.
return response, nil
}

func (r *ResolverApi) ResolveSimple(request *messages.ResolveSimpleRequest) (*resolver.ResolvedFlag, error) {
ctx := context.Background()
reqPtr := r.transferRequest(request)

results, err := r.wasmMsgGuestResolveSimple.Call(ctx, uint64(reqPtr))
if err != nil {
return nil, fmt.Errorf("failed to call wasm_msg_guest_resolve_simple: %w", err)
}

respPtr := uint32(results[0])

response := &resolver.ResolvedFlag{}
err = r.consumeResponse(respPtr, response)
if err != nil {
return nil, err
}

return response, nil
}

// transferRequest transfers a protobuf message to WASM memory
func (r *ResolverApi) transferRequest(message proto.Message) uint32 {
data := mustMarshal(message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,14 @@
import com.dylibso.chicory.wasm.WasmModule;
import com.dylibso.chicory.wasm.types.FunctionType;
import com.dylibso.chicory.wasm.types.ValType;
import com.google.protobuf.ByteString;
import com.google.protobuf.GeneratedMessage;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Timestamp;
import com.spotify.confidence.flags.resolver.v1.ResolveFlagsRequest;
import com.spotify.confidence.flags.resolver.v1.ResolveFlagsResponse;
import com.spotify.confidence.flags.resolver.v1.ResolvedFlag;

import rust_guest.Messages.SetResolverStateRequest;
import rust_guest.Messages;
import rust_guest.Messages.ResolveSimpleRequest;
import rust_guest.Types;

import java.util.List;
import java.util.function.Function;
Expand All @@ -37,33 +33,19 @@ public class ResolverApi {
// api
private final ExportFunction wasmMsgGuestSetResolverState;
private final ExportFunction wasmMsgGuestResolve;
private final ExportFunction wasmMsgGuestResolveSimple;

public ResolverApi(WasmModule module) {

instance = Instance.builder(module)
.withImportValues(ImportValues.builder()
.addFunction(createImportFunction("current_time", Messages.Void::parseFrom, this::currentTime))
.addFunction(createImportFunction("log_resolve", Types.LogResolveRequest::parseFrom, this::logResolve))
.addFunction(createImportFunction("log_assign", Types.LogAssignRequest::parseFrom, this::logAssign))
.addFunction(new ImportFunction("wasm_msg", "wasm_msg_current_thread_id", FunctionType.of(List.of(), List.of(ValType.I32)), (instance1, args) -> new long[]{0}))
.build())
.withMachineFactory(MachineFactoryCompiler::compile)
.build();
wasmMsgAlloc = instance.export("wasm_msg_alloc");
wasmMsgFree = instance.export("wasm_msg_free");
wasmMsgGuestSetResolverState = instance.export("wasm_msg_guest_set_resolver_state");
wasmMsgGuestResolve = instance.export("wasm_msg_guest_resolve");
wasmMsgGuestResolveSimple = instance.export("wasm_msg_guest_resolve_simple");
}

private GeneratedMessage logAssign(Types.LogAssignRequest logAssignRequest) {
System.out.println("logAssign");
return Messages.Void.getDefaultInstance();
}

private GeneratedMessage logResolve(Types.LogResolveRequest logResolveRequest) {
return Messages.Void.getDefaultInstance();
}

private Timestamp currentTime(Messages.Void unused) {
Expand All @@ -85,12 +67,6 @@ public ResolveFlagsResponse resolve(ResolveFlagsRequest request) {
return consumeResponse(respPtr, ResolveFlagsResponse::parseFrom);
}

public ResolvedFlag resolve_simple(ResolveSimpleRequest request) {
int reqPtr = transferRequest(request);
int respPtr = (int) wasmMsgGuestResolveSimple.apply(reqPtr)[0];
return consumeResponse(respPtr, ResolvedFlag::parseFrom);
}

private <T extends GeneratedMessage> T consumeResponse(int addr, ParserFn<T> codec) {
try {
Messages.Response response = Messages.Response.parseFrom(consume(addr));
Expand Down
21 changes: 0 additions & 21 deletions wasm/python-host/resolver_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,6 @@ def __init__(self, wasm_bytes: bytes):
def _register_host_functions(self):
"""Register host functions that can be called from WASM"""

def log_resolve(ptr: int) -> int:
# Ignore payload; return Void
response = messages_pb2.Response()
response.data = messages_pb2.Void().SerializeToString()
return self._transfer_response(response)

def log_assign(ptr: int) -> int:
# Ignore payload; return Void
response = messages_pb2.Response()
response.data = messages_pb2.Void().SerializeToString()
return self._transfer_response(response)

def current_time(ptr: int) -> int:
"""Host function to return current timestamp"""
try:
Expand All @@ -75,20 +63,11 @@ def current_time(ptr: int) -> int:
# Create function type: takes one i32 parameter, returns one i32
func_type = FuncType([ValType.i32()], [ValType.i32()])
host_func_time = Func(self.store, func_type, current_time)
host_func_log_resolve = Func(self.store, func_type, log_resolve)
host_func_log_assign = Func(self.store, func_type, log_assign)

linker = Linker(self.store.engine)

# Define the import with module and name
linker.define(self.store, "wasm_msg", "wasm_msg_host_current_time", host_func_time)
linker.define(self.store, "wasm_msg", "wasm_msg_host_log_resolve", host_func_log_resolve)
linker.define(self.store, "wasm_msg", "wasm_msg_host_log_assign", host_func_log_assign)

# Optional: current thread id function
def current_thread_id() -> int:
return 0
linker.define(self.store, "wasm_msg", "wasm_msg_current_thread_id", Func(self.store, FuncType([], [ValType.i32()]), current_thread_id))

# Instantiate the module with imports
self.instance = linker.instantiate(self.store, self.module)
Expand Down
51 changes: 23 additions & 28 deletions wasm/rust-guest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use rand::distr::Alphanumeric;
use rand::distr::SampleString;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use wasm_msg;
use wasm_msg::wasm_msg_guest;
use wasm_msg::wasm_msg_host;
use wasm_msg::WasmResult;
Expand All @@ -28,23 +27,26 @@ use confidence_resolver::{
proto::{
confidence::flags::admin::v1::ResolverState as ResolverStatePb,
confidence::flags::resolver::v1::{
ResolveFlagsRequest, ResolveFlagsResponse, ResolveWithStickyResponse, ResolvedFlag, Sdk,
ResolveFlagsRequest, ResolveFlagsResponse, ResolveWithStickyResponse, Sdk,
},
google::{Struct, Timestamp},
},
Client, FlagToApply, Host, ResolveReason, ResolvedValue, ResolverState,
};
use proto::{ResolveSimpleRequest, Void};
use proto::Void;

impl Into<proto::FallthroughAssignment>
for confidence_resolver::proto::confidence::flags::resolver::v1::events::FallthroughAssignment
impl
From<confidence_resolver::proto::confidence::flags::resolver::v1::events::FallthroughAssignment>
for proto::FallthroughAssignment
{
fn into(self) -> proto::FallthroughAssignment {
fn from(
val: confidence_resolver::proto::confidence::flags::resolver::v1::events::FallthroughAssignment,
) -> Self {
proto::FallthroughAssignment {
rule: self.rule,
assignment_id: self.assignment_id,
targeting_key: self.targeting_key,
targeting_key_selector: self.targeting_key_selector,
rule: val.rule,
assignment_id: val.assignment_id,
targeting_key: val.targeting_key,
targeting_key_selector: val.targeting_key_selector,
}
}
}
Expand All @@ -63,16 +65,17 @@ thread_local! {
});
}

impl<'a> Into<proto::ResolvedValue> for &ResolvedValue<'a> {
fn into(self) -> proto::ResolvedValue {
impl<'a> From<&ResolvedValue<'a>> for proto::ResolvedValue {
fn from(val: &ResolvedValue<'a>) -> Self {
proto::ResolvedValue {
flag: Some(proto::Flag {
name: self.flag.name.clone(),
name: val.flag.name.clone(),
}),
reason: convert_reason(self.reason.clone()),
assignment_match: match &self.assignment_match {
None => None,
Some(am) => Some(proto::AssignmentMatch {
reason: convert_reason(val.reason),
assignment_match: val
.assignment_match
.as_ref()
.map(|am| proto::AssignmentMatch {
matched_rule: Some(proto::MatchedRule {
name: am.rule.clone().name,
}),
Expand All @@ -84,8 +87,7 @@ impl<'a> Into<proto::ResolvedValue> for &ResolvedValue<'a> {
}),
assignment_id: am.assignment_id.to_string(),
}),
},
fallthrough_rules: self
fallthrough_rules: val
.fallthrough_rules
.iter()
.map(|fr| proto::FallthroughRule {
Expand Down Expand Up @@ -185,21 +187,14 @@ wasm_msg_guest! {
let resolve_request = &request.resolve_request.clone().unwrap();
let evaluation_context = resolve_request.evaluation_context.clone().unwrap();
let resolver = resolver_state.get_resolver::<WasmHost>(resolve_request.client_secret.as_str(), evaluation_context, &ENCRYPTION_KEY)?;
resolver.resolve_flags_sticky(&request).into()
resolver.resolve_flags_sticky(&request)
}

fn resolve(request: ResolveFlagsRequest) -> WasmResult<ResolveFlagsResponse> {
let resolver_state = get_resolver_state()?;
let evaluation_context = request.evaluation_context.as_ref().cloned().unwrap_or_default();
let resolver = resolver_state.get_resolver::<WasmHost>(&request.client_secret, evaluation_context, &ENCRYPTION_KEY)?;
resolver.resolve_flags(&request).into()
}
fn resolve_simple(request: ResolveSimpleRequest) -> WasmResult<ResolvedFlag> {
let resolver_state = get_resolver_state()?;
let evaluation_context = request.evaluation_context.as_ref().cloned().unwrap_or_default();
let resolver = resolver_state.get_resolver::<WasmHost>(&request.client_secret, evaluation_context, &ENCRYPTION_KEY).unwrap();
let resolve_result = resolver.resolve_flag_name(&request.name)?;
Ok((&resolve_result.resolved_value).into())
resolver.resolve_flags(&request)
}
fn flush_logs(_request:Void) -> WasmResult<WriteFlagLogsRequest> {
let response = LOGGER.checkpoint();
Expand Down