Skip to content
1 change: 1 addition & 0 deletions android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ android {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}

}

repositories {
Expand Down
5 changes: 2 additions & 3 deletions common/rnexecutorch/RnExecutorchInstaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include <rnexecutorch/host_objects/JsiConversions.h>
#include <rnexecutorch/host_objects/ModelHostObject.h>
#include <rnexecutorch/jsi/JsiPromise.h>
#include <rnexecutorch/models/StyleTransfer.h>

namespace rnexecutorch {
Expand All @@ -27,8 +26,8 @@ jsi::Function RnExecutorchInstaller::loadStyleTransfer(
auto styleTransferPtr =
std::make_shared<StyleTransfer>(source, &runtime);
auto styleTransferHostObject =
std::make_shared<ModelHostObject<StyleTransfer>>(
styleTransferPtr, &runtime, jsCallInvoker);
std::make_shared<ModelHostObject<StyleTransfer>>(styleTransferPtr,
jsCallInvoker);

return jsi::Object::createFromHostObject(runtime,
styleTransferHostObject);
Expand Down
92 changes: 55 additions & 37 deletions common/rnexecutorch/host_objects/ModelHostObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,26 @@
#include <tuple>
#include <vector>

#include <ReactCommon/CallInvoker.h>

#include <rnexecutorch/Log.h>
#include <rnexecutorch/host_objects/JsiConversions.h>
#include <rnexecutorch/jsi/JsiHostObject.h>
#include <rnexecutorch/jsi/JsiPromise.h>
#include <rnexecutorch/jsi/Promise.h>

namespace rnexecutorch {

template <typename Model> class ModelHostObject : public JsiHostObject {
public:
explicit ModelHostObject(
const std::shared_ptr<Model> &model, jsi::Runtime *runtime,
const std::shared_ptr<react::CallInvoker> &callInvoker)
: model(model), promiseVendor(runtime, callInvoker) {
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, forward));
explicit ModelHostObject(const std::shared_ptr<Model> &model,
std::shared_ptr<react::CallInvoker> callInvoker)
: model(model), callInvoker(callInvoker) {
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>, forward));
}

JSI_HOST_FUNCTION(forward) {
auto promise = promiseVendor.createPromise(
auto promise = Promise::createPromise(
runtime, callInvoker,
[this, count, args, &runtime](std::shared_ptr<Promise> promise) {
constexpr std::size_t forwardArgCount =
jsiconversion::getArgumentCount(&Model::forward);
Expand All @@ -32,48 +34,64 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
errorMessage, sizeof(errorMessage),
"Argument count mismatch, was expecting: %zu but got: %zu",
forwardArgCount, count);

promise->reject(errorMessage);
return;
}

// Do the asynchronous work
std::thread([this, promise = std::move(promise), args, &runtime]() {
try {
auto argsConverted = jsiconversion::createArgsTupleFromJsi(
&Model::forward, args, runtime);
auto result = std::apply(std::bind_front(&Model::forward, model),
argsConverted);
try {
auto argsConverted = jsiconversion::createArgsTupleFromJsi(
&Model::forward, args, runtime);

promise->resolve([result =
std::move(result)](jsi::Runtime &runtime) {
return jsiconversion::getJsiValue(std::move(result), runtime);
});
} catch (const std::runtime_error &e) {
// This catch should be merged with the next one
// (std::runtime_error inherits from std::exception) HOWEVER react
// native has broken RTTI which breaks proper exception type
// checking. Remove when the following change is present in our
// version:
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
promise->reject(e.what());
return;
} catch (const std::exception &e) {
promise->reject(e.what());
return;
} catch (...) {
promise->reject("Unknown error");
return;
}
}).detach();
// We need to dispatch a thread if we want the forward to be
// asynchronous. In this thread all accesses to jsi::Runtime need to
// be done via the callInvoker.
std::thread([this, promise,
argsConverted = std::move(argsConverted)]() {
try {
auto result = std::apply(
std::bind_front(&Model::forward, model), argsConverted);

callInvoker->invokeAsync([promise, result = std::move(result)](
jsi::Runtime &runtime) {
promise->resolve(
jsiconversion::getJsiValue(std::move(result), runtime));
});
} catch (const std::runtime_error &e) {
// This catch should be merged with the next two
// (std::runtime_error and jsi::JSError inherits from
// std::exception) HOWEVER react native has broken RTTI which
// breaks proper exception type checking. Remove when the
// following change is present in our version:
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
callInvoker->invokeAsync(
[&e, promise]() { promise->reject(e.what()); });
return;
} catch (const jsi::JSError &e) {
callInvoker->invokeAsync(
[&e, promise]() { promise->reject(e.what()); });
return;
} catch (const std::exception &e) {
callInvoker->invokeAsync(
[&e, promise]() { promise->reject(e.what()); });
return;
} catch (...) {
callInvoker->invokeAsync(
[promise]() { promise->reject("Unknown error"); });
return;
}
}).detach();
} catch (...) {
promise->reject(
"Couldn't parse JS arguments in native forward function");
}
});

return promise;
}

private:
std::shared_ptr<Model> model;
PromiseVendor promiseVendor;
std::shared_ptr<react::CallInvoker> callInvoker;
};

} // namespace rnexecutorch
60 changes: 0 additions & 60 deletions common/rnexecutorch/jsi/JsiPromise.cpp

This file was deleted.

48 changes: 0 additions & 48 deletions common/rnexecutorch/jsi/JsiPromise.h

This file was deleted.

20 changes: 20 additions & 0 deletions common/rnexecutorch/jsi/Promise.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "Promise.h"

namespace rnexecutorch {

Promise::Promise(jsi::Runtime &runtime,
std::shared_ptr<react::CallInvoker> callInvoker,
jsi::Value resolver, jsi::Value rejecter)
: runtime(runtime), callInvoker(callInvoker),
_resolver(std::move(resolver)), _rejecter(std::move(rejecter)) {}

void Promise::resolve(jsi::Value &&result) {
_resolver.asObject(runtime).asFunction(runtime).call(runtime, result);
}

void Promise::reject(std::string message) {
jsi::JSError error(runtime, message);
_rejecter.asObject(runtime).asFunction(runtime).call(runtime, error.value());
}

} // namespace rnexecutorch
69 changes: 69 additions & 0 deletions common/rnexecutorch/jsi/Promise.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#pragma once

#include <memory>
#include <string>

#include <ReactCommon/CallInvoker.h>
#include <jsi/jsi.h>

namespace rnexecutorch {

using namespace facebook;

class Promise;

template <typename T>
concept PromiseRunFn =
std::invocable<T, std::shared_ptr<Promise>> &&
std::same_as<std::invoke_result_t<T, std::shared_ptr<Promise>>, void>;

class Promise {
public:
Promise(jsi::Runtime &runtime,
std::shared_ptr<react::CallInvoker> callInvoker, jsi::Value resolver,
jsi::Value rejecter);

Promise(const Promise &) = delete;
Promise &operator=(const Promise &) = delete;

void resolve(jsi::Value &&result);
void reject(std::string error);

/**
Creates a new promise and runs the supplied "run" function that takes this
promise. We use a template for the function type to not use std::function
and be able to bind a lambda.
*/
template <PromiseRunFn Fn>
static jsi::Value
createPromise(jsi::Runtime &runtime,
std::shared_ptr<react::CallInvoker> callInvoker, Fn &&run) {
// Get Promise ctor from global
auto promiseCtor =
runtime.global().getPropertyAsFunction(runtime, "Promise");

auto promiseCallback = jsi::Function::createFromHostFunction(
runtime, jsi::PropNameID::forUtf8(runtime, "PromiseCallback"), 2,
[run = std::move(run),
callInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue,
const jsi::Value *arguments, size_t count) -> jsi::Value {
// Call function
auto promise = std::make_shared<Promise>(
runtime, callInvoker, arguments[0].asObject(runtime),
arguments[1].asObject(runtime));
run(promise);

return jsi::Value::undefined();
});

return promiseCtor.callAsConstructor(runtime, promiseCallback);
}

private:
jsi::Runtime &runtime;
std::shared_ptr<react::CallInvoker> callInvoker;
jsi::Value _resolver;
jsi::Value _rejecter;
};

} // namespace rnexecutorch
2 changes: 1 addition & 1 deletion common/rnexecutorch/models/BaseModel.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <rnexecutorch/models/BaseModel.h>
#include "BaseModel.h"

#include <rnexecutorch/Log.h>

Expand Down
6 changes: 3 additions & 3 deletions common/rnexecutorch/models/StyleTransfer.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include "StyleTransfer.h"

#include <rnexecutorch/Log.h>
#include <rnexecutorch/data_processing/ImageProcessing.h>

#include <span>

#include <executorch/extension/tensor/tensor.h>
#include <opencv2/opencv.hpp>

#include <rnexecutorch/Log.h>
#include <rnexecutorch/data_processing/ImageProcessing.h>

namespace rnexecutorch {
using namespace facebook;
using executorch::extension::Module;
Expand Down