diff --git a/android/build.gradle b/android/build.gradle index efccc52f2..7d4ed01c4 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -142,6 +142,7 @@ android { sourceCompatibility JavaVersion.VERSION_1_8 targetCompatibility JavaVersion.VERSION_1_8 } + } repositories { diff --git a/common/rnexecutorch/RnExecutorchInstaller.cpp b/common/rnexecutorch/RnExecutorchInstaller.cpp index 2aef62a04..5fb2eb119 100644 --- a/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -2,7 +2,6 @@ #include #include -#include #include namespace rnexecutorch { @@ -27,8 +26,8 @@ jsi::Function RnExecutorchInstaller::loadStyleTransfer( auto styleTransferPtr = std::make_shared(source, &runtime); auto styleTransferHostObject = - std::make_shared>( - styleTransferPtr, &runtime, jsCallInvoker); + std::make_shared>(styleTransferPtr, + jsCallInvoker); return jsi::Object::createFromHostObject(runtime, styleTransferHostObject); diff --git a/common/rnexecutorch/host_objects/ModelHostObject.h b/common/rnexecutorch/host_objects/ModelHostObject.h index de58b3738..4e31c8583 100644 --- a/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/common/rnexecutorch/host_objects/ModelHostObject.h @@ -5,24 +5,26 @@ #include #include +#include + #include #include #include -#include +#include namespace rnexecutorch { template class ModelHostObject : public JsiHostObject { public: - explicit ModelHostObject( - const std::shared_ptr &model, jsi::Runtime *runtime, - const std::shared_ptr &callInvoker) - : model(model), promiseVendor(runtime, callInvoker) { - addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, forward)); + explicit ModelHostObject(const std::shared_ptr &model, + std::shared_ptr callInvoker) + : model(model), callInvoker(callInvoker) { + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, forward)); } JSI_HOST_FUNCTION(forward) { - auto promise = promiseVendor.createPromise( + auto promise = Promise::createPromise( + runtime, callInvoker, [this, count, args, &runtime](std::shared_ptr promise) { constexpr std::size_t forwardArgCount = jsiconversion::getArgumentCount(&Model::forward); @@ -32,40 +34,56 @@ template class ModelHostObject : public JsiHostObject { errorMessage, sizeof(errorMessage), "Argument count mismatch, was expecting: %zu but got: %zu", forwardArgCount, count); - promise->reject(errorMessage); return; } - // Do the asynchronous work - std::thread([this, promise = std::move(promise), args, &runtime]() { - try { - auto argsConverted = jsiconversion::createArgsTupleFromJsi( - &Model::forward, args, runtime); - auto result = std::apply(std::bind_front(&Model::forward, model), - argsConverted); + try { + auto argsConverted = jsiconversion::createArgsTupleFromJsi( + &Model::forward, args, runtime); - promise->resolve([result = - std::move(result)](jsi::Runtime &runtime) { - return jsiconversion::getJsiValue(std::move(result), runtime); - }); - } catch (const std::runtime_error &e) { - // This catch should be merged with the next one - // (std::runtime_error inherits from std::exception) HOWEVER react - // native has broken RTTI which breaks proper exception type - // checking. Remove when the following change is present in our - // version: - // https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e - promise->reject(e.what()); - return; - } catch (const std::exception &e) { - promise->reject(e.what()); - return; - } catch (...) { - promise->reject("Unknown error"); - return; - } - }).detach(); + // We need to dispatch a thread if we want the forward to be + // asynchronous. In this thread all accesses to jsi::Runtime need to + // be done via the callInvoker. + std::thread([this, promise, + argsConverted = std::move(argsConverted)]() { + try { + auto result = std::apply( + std::bind_front(&Model::forward, model), argsConverted); + + callInvoker->invokeAsync([promise, result = std::move(result)]( + jsi::Runtime &runtime) { + promise->resolve( + jsiconversion::getJsiValue(std::move(result), runtime)); + }); + } catch (const std::runtime_error &e) { + // This catch should be merged with the next two + // (std::runtime_error and jsi::JSError inherits from + // std::exception) HOWEVER react native has broken RTTI which + // breaks proper exception type checking. Remove when the + // following change is present in our version: + // https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e + callInvoker->invokeAsync( + [&e, promise]() { promise->reject(e.what()); }); + return; + } catch (const jsi::JSError &e) { + callInvoker->invokeAsync( + [&e, promise]() { promise->reject(e.what()); }); + return; + } catch (const std::exception &e) { + callInvoker->invokeAsync( + [&e, promise]() { promise->reject(e.what()); }); + return; + } catch (...) { + callInvoker->invokeAsync( + [promise]() { promise->reject("Unknown error"); }); + return; + } + }).detach(); + } catch (...) { + promise->reject( + "Couldn't parse JS arguments in native forward function"); + } }); return promise; @@ -73,7 +91,7 @@ template class ModelHostObject : public JsiHostObject { private: std::shared_ptr model; - PromiseVendor promiseVendor; + std::shared_ptr callInvoker; }; } // namespace rnexecutorch \ No newline at end of file diff --git a/common/rnexecutorch/jsi/JsiPromise.cpp b/common/rnexecutorch/jsi/JsiPromise.cpp deleted file mode 100644 index 4487d0073..000000000 --- a/common/rnexecutorch/jsi/JsiPromise.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "JsiPromise.h" - -namespace rnexecutorch { - -using namespace facebook; - -jsi::Value PromiseVendor::createPromise( - const std::function)> &function) { - if (runtime_ == nullptr) { - throw std::runtime_error("Runtime was null!"); - } - - auto &runtime = *runtime_; - auto callInvoker = callInvoker_; - - // get Promise constructor - auto promiseCtor = runtime.global().getPropertyAsFunction(runtime, "Promise"); - - // create a "run" function (first Promise arg) - auto runPromise = jsi::Function::createFromHostFunction( - runtime, jsi::PropNameID::forUtf8(runtime, "runPromise"), 2, - [callInvoker, - function](jsi::Runtime &runtime, const jsi::Value &thisValue, - const jsi::Value *arguments, size_t count) -> jsi::Value { - auto resolveLocal = arguments[0].asObject(runtime).asFunction(runtime); - auto resolve = std::make_shared(std::move(resolveLocal)); - auto rejectLocal = arguments[1].asObject(runtime).asFunction(runtime); - auto reject = std::make_shared(std::move(rejectLocal)); - - auto resolveWrapper = - [resolve, &runtime, callInvoker]( - const std::function &resolver) - -> void { - callInvoker->invokeAsync([resolve, &runtime, resolver]() -> void { - auto valueShared = std::make_shared(resolver(runtime)); - - resolve->call(runtime, *valueShared); - }); - }; - - auto rejectWrapper = [reject, &runtime, callInvoker]( - const std::string &errorMessage) -> void { - callInvoker->invokeAsync([reject, &runtime, errorMessage]() -> void { - auto error = jsi::JSError(runtime, errorMessage); - auto errorShared = std::make_shared(error); - reject->call(runtime, errorShared->value()); - }); - }; - - auto promise = std::make_shared(resolveWrapper, rejectWrapper); - function(promise); - - return jsi::Value::undefined(); - }); - - // return new Promise((resolve, reject) => ...) - return promiseCtor.callAsConstructor(runtime, runPromise); -} - -} // namespace rnexecutorch \ No newline at end of file diff --git a/common/rnexecutorch/jsi/JsiPromise.h b/common/rnexecutorch/jsi/JsiPromise.h deleted file mode 100644 index 3109a9982..000000000 --- a/common/rnexecutorch/jsi/JsiPromise.h +++ /dev/null @@ -1,48 +0,0 @@ -#pragma once -// Adapted from https://github.com/software-mansion/react-native-audio-api - -#include -#include -#include -#include -#include -#include - -namespace rnexecutorch { - -using namespace facebook; - -class Promise { -public: - Promise(std::function)> - resolve, - std::function reject) - : resolve_(std::move(resolve)), reject_(std::move(reject)) {} - - void resolve(const std::function &resolver) { - resolve_(std::forward>( - resolver)); - } - - void reject(const std::string &errorMessage) { reject_(errorMessage); } - -private: - std::function)> resolve_; - std::function reject_; -}; - -class PromiseVendor { -public: - PromiseVendor(jsi::Runtime *runtime, - const std::shared_ptr &callInvoker) - : runtime_(runtime), callInvoker_(callInvoker) {} - - jsi::Value - createPromise(const std::function)> &function); - -private: - jsi::Runtime *runtime_; - std::shared_ptr callInvoker_; -}; - -} // namespace rnexecutorch \ No newline at end of file diff --git a/common/rnexecutorch/jsi/Promise.cpp b/common/rnexecutorch/jsi/Promise.cpp new file mode 100644 index 000000000..f08f26654 --- /dev/null +++ b/common/rnexecutorch/jsi/Promise.cpp @@ -0,0 +1,20 @@ +#include "Promise.h" + +namespace rnexecutorch { + +Promise::Promise(jsi::Runtime &runtime, + std::shared_ptr callInvoker, + jsi::Value resolver, jsi::Value rejecter) + : runtime(runtime), callInvoker(callInvoker), + _resolver(std::move(resolver)), _rejecter(std::move(rejecter)) {} + +void Promise::resolve(jsi::Value &&result) { + _resolver.asObject(runtime).asFunction(runtime).call(runtime, result); +} + +void Promise::reject(std::string message) { + jsi::JSError error(runtime, message); + _rejecter.asObject(runtime).asFunction(runtime).call(runtime, error.value()); +} + +} // namespace rnexecutorch \ No newline at end of file diff --git a/common/rnexecutorch/jsi/Promise.h b/common/rnexecutorch/jsi/Promise.h new file mode 100644 index 000000000..4dba08891 --- /dev/null +++ b/common/rnexecutorch/jsi/Promise.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include + +#include +#include + +namespace rnexecutorch { + +using namespace facebook; + +class Promise; + +template +concept PromiseRunFn = + std::invocable> && + std::same_as>, void>; + +class Promise { +public: + Promise(jsi::Runtime &runtime, + std::shared_ptr callInvoker, jsi::Value resolver, + jsi::Value rejecter); + + Promise(const Promise &) = delete; + Promise &operator=(const Promise &) = delete; + + void resolve(jsi::Value &&result); + void reject(std::string error); + + /** + Creates a new promise and runs the supplied "run" function that takes this + promise. We use a template for the function type to not use std::function + and be able to bind a lambda. + */ + template + static jsi::Value + createPromise(jsi::Runtime &runtime, + std::shared_ptr callInvoker, Fn &&run) { + // Get Promise ctor from global + auto promiseCtor = + runtime.global().getPropertyAsFunction(runtime, "Promise"); + + auto promiseCallback = jsi::Function::createFromHostFunction( + runtime, jsi::PropNameID::forUtf8(runtime, "PromiseCallback"), 2, + [run = std::move(run), + callInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue, + const jsi::Value *arguments, size_t count) -> jsi::Value { + // Call function + auto promise = std::make_shared( + runtime, callInvoker, arguments[0].asObject(runtime), + arguments[1].asObject(runtime)); + run(promise); + + return jsi::Value::undefined(); + }); + + return promiseCtor.callAsConstructor(runtime, promiseCallback); + } + +private: + jsi::Runtime &runtime; + std::shared_ptr callInvoker; + jsi::Value _resolver; + jsi::Value _rejecter; +}; + +} // namespace rnexecutorch \ No newline at end of file diff --git a/common/rnexecutorch/models/BaseModel.cpp b/common/rnexecutorch/models/BaseModel.cpp index b44b83d74..5da87eeee 100644 --- a/common/rnexecutorch/models/BaseModel.cpp +++ b/common/rnexecutorch/models/BaseModel.cpp @@ -1,4 +1,4 @@ -#include +#include "BaseModel.h" #include diff --git a/common/rnexecutorch/models/StyleTransfer.cpp b/common/rnexecutorch/models/StyleTransfer.cpp index 3aacdb6b8..54aa6d318 100644 --- a/common/rnexecutorch/models/StyleTransfer.cpp +++ b/common/rnexecutorch/models/StyleTransfer.cpp @@ -1,13 +1,13 @@ #include "StyleTransfer.h" +#include +#include + #include #include #include -#include -#include - namespace rnexecutorch { using namespace facebook; using executorch::extension::Module;