diff --git a/android/src/main/cpp/ETInstallerModule.cpp b/android/src/main/cpp/ETInstallerModule.cpp index bdb407a49..d69af8634 100644 --- a/android/src/main/cpp/ETInstallerModule.cpp +++ b/android/src/main/cpp/ETInstallerModule.cpp @@ -53,6 +53,10 @@ void ETInstallerModule::injectJSIBindings() { jbyteArray byteData = (jbyteArray)env->CallStaticObjectMethod(cls, method, jUrl); + if (env->IsSameObject(byteData, NULL)) { + throw std::runtime_error("Error fetching data from a url"); + } + int size = env->GetArrayLength(byteData); jbyte *bytes = env->GetByteArrayElements(byteData, JNI_FALSE); std::byte *dataBytePtr = reinterpret_cast(bytes); diff --git a/android/src/main/java/com/swmansion/rnexecutorch/ETInstaller.kt b/android/src/main/java/com/swmansion/rnexecutorch/ETInstaller.kt index 96bce89d1..acc43c0a9 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/ETInstaller.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/ETInstaller.kt @@ -21,16 +21,20 @@ class ETInstaller( @JvmStatic @DoNotStrip @Throws(Exception::class) - fun fetchByteDataFromUrl(source: String): ByteArray { - val url = URL(source) - val connection = url.openConnection() - connection.connect() + fun fetchByteDataFromUrl(source: String): ByteArray? { + try { + val url = URL(source) + val connection = url.openConnection() + connection.connect() - val inputStream: InputStream = connection.getInputStream() - val data = inputStream.readBytes() - inputStream.close() + val inputStream: InputStream = connection.getInputStream() + val data = inputStream.readBytes() + inputStream.close() - return data + return data + } catch (exception: Throwable) { + return null + } } } diff --git a/common/rnexecutorch/data_processing/ImageProcessing.cpp b/common/rnexecutorch/data_processing/ImageProcessing.cpp index 5b47409ee..20461d169 100644 --- a/common/rnexecutorch/data_processing/ImageProcessing.cpp +++ b/common/rnexecutorch/data_processing/ImageProcessing.cpp @@ -97,12 +97,14 @@ cv::Mat readImage(const std::string &imageURI) { // local file auto url = ada::parse(imageURI); image = cv::imread(std::string{url->get_pathname()}, cv::IMREAD_COLOR); - } else { + } else if (imageURI.starts_with("http")) { // remote file std::vector imageData = fetchUrlFunc(imageURI); image = cv::imdecode( cv::Mat(1, imageData.size(), CV_8UC1, (void *)imageData.data()), cv::IMREAD_COLOR); + } else { + throw std::runtime_error("Read image error: unknown protocol"); } if (image.empty()) { diff --git a/common/rnexecutorch/host_objects/ModelHostObject.h b/common/rnexecutorch/host_objects/ModelHostObject.h index 2d5f8d6c8..de58b3738 100644 --- a/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/common/rnexecutorch/host_objects/ModelHostObject.h @@ -1,3 +1,5 @@ +#pragma once + #include #include #include @@ -20,38 +22,48 @@ template class ModelHostObject : public JsiHostObject { } JSI_HOST_FUNCTION(forward) { - auto promise = promiseVendor.createPromise( [this, count, args, &runtime](std::shared_ptr promise) { - std::thread([this, promise = std::move(promise), count, args, - &runtime]() { - constexpr std::size_t forwardArgCount = - jsiconversion::getArgumentCount(&Model::forward); - if (forwardArgCount != count) { - char errorMessage[100]; - std::snprintf( - errorMessage, sizeof(errorMessage), - "Argument count mismatch, was expecting: %zu but got: %zu", - forwardArgCount, count); + constexpr std::size_t forwardArgCount = + jsiconversion::getArgumentCount(&Model::forward); + if (forwardArgCount != count) { + char errorMessage[100]; + std::snprintf( + errorMessage, sizeof(errorMessage), + "Argument count mismatch, was expecting: %zu but got: %zu", + forwardArgCount, count); - promise->reject(errorMessage); - return; - } + promise->reject(errorMessage); + return; + } + // Do the asynchronous work + std::thread([this, promise = std::move(promise), args, &runtime]() { try { auto argsConverted = jsiconversion::createArgsTupleFromJsi( &Model::forward, args, runtime); - promise->resolve([this, argsConverted = std::move(argsConverted)]( - jsi::Runtime &runtime) { - auto result = std::apply( - std::bind_front(&Model::forward, model), argsConverted); - auto resultValue = - jsiconversion::getJsiValue(std::move(result), runtime); - return resultValue; + auto result = std::apply(std::bind_front(&Model::forward, model), + argsConverted); + + promise->resolve([result = + std::move(result)](jsi::Runtime &runtime) { + return jsiconversion::getJsiValue(std::move(result), runtime); }); + } catch (const std::runtime_error &e) { + // This catch should be merged with the next one + // (std::runtime_error inherits from std::exception) HOWEVER react + // native has broken RTTI which breaks proper exception type + // checking. Remove when the following change is present in our + // version: + // https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e + promise->reject(e.what()); + return; } catch (const std::exception &e) { promise->reject(e.what()); return; + } catch (...) { + promise->reject("Unknown error"); + return; } }).detach(); }); diff --git a/common/rnexecutorch/jsi/JsiPromise.cpp b/common/rnexecutorch/jsi/JsiPromise.cpp index 2043ac92a..4487d0073 100644 --- a/common/rnexecutorch/jsi/JsiPromise.cpp +++ b/common/rnexecutorch/jsi/JsiPromise.cpp @@ -40,9 +40,9 @@ jsi::Value PromiseVendor::createPromise( auto rejectWrapper = [reject, &runtime, callInvoker]( const std::string &errorMessage) -> void { - auto error = jsi::JSError(runtime, errorMessage); - auto errorShared = std::make_shared(error); - callInvoker->invokeAsync([reject, &runtime, errorShared]() -> void { + callInvoker->invokeAsync([reject, &runtime, errorMessage]() -> void { + auto error = jsi::JSError(runtime, errorMessage); + auto errorShared = std::make_shared(error); reject->call(runtime, errorShared->value()); }); }; diff --git a/common/rnexecutorch/models/StyleTransfer.cpp b/common/rnexecutorch/models/StyleTransfer.cpp index f585ca0d6..3aacdb6b8 100644 --- a/common/rnexecutorch/models/StyleTransfer.cpp +++ b/common/rnexecutorch/models/StyleTransfer.cpp @@ -2,6 +2,7 @@ #include +#include #include #include diff --git a/ios/RnExecutorch/ETInstaller.mm b/ios/RnExecutorch/ETInstaller.mm index 7eb1be387..dcb40b29d 100644 --- a/ios/RnExecutorch/ETInstaller.mm +++ b/ios/RnExecutorch/ETInstaller.mm @@ -5,6 +5,7 @@ #import #import #include +#include using namespace facebook::react; @@ -26,14 +27,19 @@ @implementation ETInstaller assert(jsiRuntime != nullptr); auto fetchUrl = [](std::string url) { - NSString *nsUrlStr = - [NSString stringWithCString:url.c_str() - encoding:[NSString defaultCStringEncoding]]; - NSURL *nsUrl = [NSURL URLWithString:nsUrlStr]; - NSData *data = [NSData dataWithContentsOfURL:nsUrl]; - const std::byte *bytePtr = reinterpret_cast(data.bytes); - int bufferLength = [data length]; - return std::vector(bytePtr, bytePtr + bufferLength); + @try { + NSString *nsUrlStr = + [NSString stringWithCString:url.c_str() + encoding:[NSString defaultCStringEncoding]]; + NSURL *nsUrl = [NSURL URLWithString:nsUrlStr]; + NSData *data = [NSData dataWithContentsOfURL:nsUrl]; + const std::byte *bytePtr = + reinterpret_cast(data.bytes); + int bufferLength = [data length]; + return std::vector(bytePtr, bytePtr + bufferLength); + } @catch (NSException *exception) { + throw std::runtime_error("Error fetching data from a url"); + } }; rnexecutorch::RnExecutorchInstaller::injectJSIBindings( jsiRuntime, jsCallInvoker, fetchUrl);