Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions common/rnexecutorch/host_objects/ModelHostObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,42 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
explicit ModelHostObject(const std::shared_ptr<Model> &model,
std::shared_ptr<react::CallInvoker> callInvoker)
: model(model), callInvoker(callInvoker) {
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>, forward));
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
promiseHostFunction<&Model::forward>,
"forward"));
}

JSI_HOST_FUNCTION(forward) {
// A generic host function that resolves a promise with a result of a
// function. JSI arguments are converted to the types provided in the function
// signature, and the return value is converted back to JSI before resolving.
template <auto FnPtr> JSI_HOST_FUNCTION(promiseHostFunction) {
auto promise = Promise::createPromise(
runtime, callInvoker,
[this, count, args, &runtime](std::shared_ptr<Promise> promise) {
constexpr std::size_t forwardArgCount =
jsiconversion::getArgumentCount(&Model::forward);
if (forwardArgCount != count) {
constexpr std::size_t functionArgCount =
jsiconversion::getArgumentCount(FnPtr);
if (functionArgCount != count) {
char errorMessage[100];
std::snprintf(
errorMessage, sizeof(errorMessage),
"Argument count mismatch, was expecting: %zu but got: %zu",
forwardArgCount, count);
functionArgCount, count);
promise->reject(errorMessage);
return;
}

try {
auto argsConverted = jsiconversion::createArgsTupleFromJsi(
&Model::forward, args, runtime);
auto argsConverted =
jsiconversion::createArgsTupleFromJsi(FnPtr, args, runtime);

// We need to dispatch a thread if we want the forward to be
// We need to dispatch a thread if we want the function to be
// asynchronous. In this thread all accesses to jsi::Runtime need to
// be done via the callInvoker.
std::thread([this, promise,
argsConverted = std::move(argsConverted)]() {
try {
auto result = std::apply(
std::bind_front(&Model::forward, model), argsConverted);
auto result =
std::apply(std::bind_front(FnPtr, model), argsConverted);

callInvoker->invokeAsync([promise, result = std::move(result)](
jsi::Runtime &runtime) {
Expand Down Expand Up @@ -81,8 +86,7 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
}
}).detach();
} catch (...) {
promise->reject(
"Couldn't parse JS arguments in native forward function");
promise->reject("Couldn't parse JS arguments in a native function");
}
});

Expand Down
4 changes: 2 additions & 2 deletions common/rnexecutorch/jsi/JsiHostObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
jsi::Value NAME(jsi::Runtime &runtime, const jsi::Value &thisValue, \
const jsi::Value *args, size_t count)

#define JSI_EXPORT_FUNCTION(CLASS, FUNCTION) \
#define JSI_EXPORT_FUNCTION(CLASS, FUNCTION, NAME) \
std::make_pair( \
std::string(#FUNCTION), \
NAME, \
static_cast<jsi::Value (JsiHostObject::*)( \
jsi::Runtime &, const jsi::Value &, const jsi::Value *, size_t)>( \
&CLASS::FUNCTION))
Expand Down