diff --git a/extension/android/BUCK b/extension/android/BUCK index 191e6ce4714..b02003fdc34 100644 --- a/extension/android/BUCK +++ b/extension/android/BUCK @@ -10,6 +10,7 @@ non_fbcode_target(_kind = fb_android_library, "executorch_android/src/main/java/org/pytorch/executorch/DType.java", "executorch_android/src/main/java/org/pytorch/executorch/EValue.java", "executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java", + "executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java", "executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java", "executorch_android/src/main/java/org/pytorch/executorch/Module.java", "executorch_android/src/main/java/org/pytorch/executorch/Tensor.java", diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 38d30854525..be6715f93d5 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -71,6 +71,7 @@ executorch_target_link_options_shared_lib(executorch) add_library( executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp + jni/jni_helper.cpp ) set(link_libraries) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java new file mode 100644 index 00000000000..de823f40afb --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java @@ -0,0 +1,125 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class ExecutorchRuntimeException extends RuntimeException { + // Error code constants - keep in sync with runtime/core/error.h + // System errors + public static final int OK = 0x00; + public static final int INTERNAL = 0x01; + public static final int INVALID_STATE = 0x02; + public static final int END_OF_METHOD = 0x03; + + // Logical errors + public static final int NOT_SUPPORTED = 0x10; + public static final int NOT_IMPLEMENTED = 0x11; + public static final int INVALID_ARGUMENT = 0x12; + public static final int INVALID_TYPE = 0x13; + public static final int OPERATOR_MISSING = 0x14; + public static final int REGISTRATION_EXCEEDING_MAX_KERNELS = 0x15; + public static final int REGISTRATION_ALREADY_REGISTERED = 0x16; + + // Resource errors + public static final int NOT_FOUND = 0x20; + public static final int MEMORY_ALLOCATION_FAILED = 0x21; + public static final int ACCESS_FAILED = 0x22; + public static final int INVALID_PROGRAM = 0x23; + public static final int INVALID_EXTERNAL_DATA = 0x24; + public static final int OUT_OF_RESOURCES = 0x25; + + // Delegate errors + public static final int DELEGATE_INVALID_COMPATIBILITY = 0x30; + public static final int DELEGATE_MEMORY_ALLOCATION_FAILED = 0x31; + public static final int DELEGATE_INVALID_HANDLE = 0x32; + + private static final Map ERROR_CODE_MESSAGES; + + static { + Map map = new HashMap<>(); + + // System errors + map.put(OK, "Operation successful"); + map.put(INTERNAL, "Internal error"); + map.put(INVALID_STATE, "Invalid state"); + map.put(END_OF_METHOD, "End of method reached"); + // Logical errors + map.put(NOT_SUPPORTED, "Operation not supported"); + map.put(NOT_IMPLEMENTED, "Operation not implemented"); + map.put(INVALID_ARGUMENT, "Invalid argument"); + map.put(INVALID_TYPE, "Invalid type"); + map.put(OPERATOR_MISSING, "Operator missing"); + map.put(REGISTRATION_EXCEEDING_MAX_KERNELS, "Exceeded max kernels"); + map.put(REGISTRATION_ALREADY_REGISTERED, "Kernel already registered"); + // Resource errors + map.put(NOT_FOUND, "Resource not found"); + map.put(MEMORY_ALLOCATION_FAILED, "Memory allocation failed"); + map.put(ACCESS_FAILED, "Access failed"); + map.put(INVALID_PROGRAM, "Invalid program"); + map.put(INVALID_EXTERNAL_DATA, "Invalid external data"); + map.put(OUT_OF_RESOURCES, "Out of resources"); + // Delegate errors + map.put(DELEGATE_INVALID_COMPATIBILITY, "Delegate invalid compatibility"); + map.put(DELEGATE_MEMORY_ALLOCATION_FAILED, "Delegate memory allocation failed"); + map.put(DELEGATE_INVALID_HANDLE, "Delegate invalid handle"); + ERROR_CODE_MESSAGES = Collections.unmodifiableMap(map); + } + + static class ErrorHelper { + static String formatMessage(int errorCode, String details) { + String baseMessage = ERROR_CODE_MESSAGES.get(errorCode); + if (baseMessage == null) { + baseMessage = "Unknown error code 0x" + Integer.toHexString(errorCode); + } + return "[Executorch Error 0x" + + Integer.toHexString(errorCode) + + "] " + + baseMessage + + ": " + + details; + } + } + + private final int errorCode; + + public ExecutorchRuntimeException(int errorCode, String details) { + super(ErrorHelper.formatMessage(errorCode, details)); + this.errorCode = errorCode; + } + + public int getErrorCode() { + return errorCode; + } + + // Idiomatic Java exception for invalid arguments. + public static class ExecutorchInvalidArgumentException extends IllegalArgumentException { + private final int errorCode = INVALID_ARGUMENT; + + public ExecutorchInvalidArgumentException(String details) { + super(ErrorHelper.formatMessage(INVALID_ARGUMENT, details)); + } + + public int getErrorCode() { + return errorCode; + } + } + + // Factory method to create an exception of the appropriate subclass. + public static RuntimeException makeExecutorchException(int errorCode, String details) { + switch (errorCode) { + case INVALID_ARGUMENT: + return new ExecutorchInvalidArgumentException(details); + default: + return new ExecutorchRuntimeException(errorCode, details); + } + } +} diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK index 2a903da3e33..0ba39a71666 100644 --- a/extension/android/jni/BUCK +++ b/extension/android/jni/BUCK @@ -28,7 +28,7 @@ non_fbcode_target(_kind = executorch_generated_lib, non_fbcode_target(_kind = fb_android_cxx_library, name = "executorch_jni", - srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"], + srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp", "jni_helper.cpp"], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS, soname = "libexecutorch.$(ext)", @@ -49,7 +49,7 @@ non_fbcode_target(_kind = fb_android_cxx_library, non_fbcode_target(_kind = fb_android_cxx_library, name = "executorch_jni_full", - srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"], + srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp", "jni_helper.cpp"], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS, soname = "libexecutorch.$(ext)", @@ -71,7 +71,7 @@ non_fbcode_target(_kind = fb_android_cxx_library, non_fbcode_target(_kind = fb_android_cxx_library, name = "executorch_training_jni", - srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp", "jni_layer_training.cpp"], + srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp", "jni_layer_training.cpp", "jni_helper.cpp"], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS + [ "-DEXECUTORCH_BUILD_EXTENSION_TRAINING", @@ -98,11 +98,7 @@ non_fbcode_target(_kind = fb_android_cxx_library, non_fbcode_target(_kind = fb_android_cxx_library, name = "executorch_llama_jni", - srcs = [ - "jni_layer.cpp", - "jni_layer_llama.cpp", - "jni_layer_runtime.cpp", - ], + srcs = ["jni_layer.cpp", "jni_layer_llama.cpp", "jni_layer_runtime.cpp", "jni_helper.cpp"], allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS + [ "-DEXECUTORCH_BUILD_LLAMA_JNI", @@ -145,9 +141,14 @@ runtime.export_file( name = "jni_layer_runtime.cpp", ) +runtime.export_file( + name = "jni_helper.cpp", +) + runtime.cxx_library( name = "jni_headers", exported_headers = [ "jni_layer_constants.h", + "jni_helper.h", ] ) diff --git a/extension/android/jni/jni_helper.cpp b/extension/android/jni/jni_helper.cpp new file mode 100644 index 00000000000..a8fb2aeddcf --- /dev/null +++ b/extension/android/jni/jni_helper.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "jni_helper.h" + +namespace executorch::jni_helper { + +void throwExecutorchException(uint32_t errorCode, const std::string& details) { + // Get the current JNI environment + auto env = facebook::jni::Environment::current(); + + // Find the Java ExecutorchRuntimeException class + static auto exceptionClass = facebook::jni::findClassLocal( + "org/pytorch/executorch/ExecutorchRuntimeException"); + + // Find the static factory method: makeExecutorchException(int, String) + static auto makeExceptionMethod = exceptionClass->getStaticMethod< + facebook::jni::local_ref( + int, facebook::jni::alias_ref)>( + "makeExecutorchException", + "(ILjava/lang/String;)Lorg/pytorch/executorch/ExecutorchRuntimeException;"); + + auto jDetails = facebook::jni::make_jstring(details); + // Call the factory method to create the exception object + auto exception = makeExceptionMethod(exceptionClass, errorCode, jDetails); + facebook::jni::throwNewJavaException(exception.get()); +} + +} // namespace executorch::jni_helper diff --git a/extension/android/jni/jni_helper.h b/extension/android/jni/jni_helper.h new file mode 100644 index 00000000000..996d75581d3 --- /dev/null +++ b/extension/android/jni/jni_helper.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch::jni_helper { + +/** + * Throws a Java ExecutorchRuntimeException corresponding to the given error + * code and details. Uses the Java factory method + * ExecutorchRuntimeException.makeExecutorchException(int, String). + * + * @param errorCode The error code from the C++ Executorch runtime. + * @param details Additional details to include in the exception message. + */ +void throwExecutorchException(uint32_t errorCode, const std::string& details); + +} // namespace executorch::jni_helper diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 7111a0bc6bc..7ad54ffc360 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -6,7 +6,9 @@ * LICENSE file in the root directory of this source tree. */ +#include #include + #include #include #include @@ -55,14 +57,14 @@ class TensorHybrid : public facebook::jni::HybridClass { // Java wrapper currently only supports contiguous tensors. const auto scalarType = tensor.scalar_type(); - + int jdtype = scalar_type_to_java_dtype.at(scalarType); if (scalar_type_to_java_dtype.count(scalarType) == 0) { - facebook::jni::throwNewJavaException( - facebook::jni::gJavaLangIllegalArgumentException, - "executorch::aten::Tensor scalar type %d is not supported on java side", - scalarType); + std::stringstream ss; + ss << "executorch::aten::Tensor scalar [java] type: " << jdtype + << " is not supported on java side"; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); } - int jdtype = scalar_type_to_java_dtype.at(scalarType); const auto& tensor_shape = tensor.sizes(); std::vector tensor_shape_vec; @@ -124,19 +126,19 @@ class TensorHybrid : public facebook::jni::HybridClass { } JNIEnv* jni = facebook::jni::Environment::current(); if (java_dtype_to_scalar_type.count(jdtype) == 0) { - facebook::jni::throwNewJavaException( - facebook::jni::gJavaLangIllegalArgumentException, - "Unknown Tensor jdtype %d", - jdtype); + std::stringstream ss; + ss << "Unknown Tensor jdtype: [" << jdtype << "]"; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); } ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype); const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get()); if (dataCapacity != numel) { - facebook::jni::throwNewJavaException( - facebook::jni::gJavaLangIllegalArgumentException, - "Tensor dimensions(elements number:%d inconsistent with buffer capacity(%d)", - numel, - dataCapacity); + std::stringstream ss; + ss << "Tensor dimensions(elements number: " << numel + << "inconsistent with buffer capacity " << dataCapacity << "]"; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); } return from_blob( jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type); @@ -194,10 +196,11 @@ class JEValue : public facebook::jni::JavaClass { return jMethodTensor( JEValue::javaClassStatic(), facebook::jni::make_jstring(str)); } - facebook::jni::throwNewJavaException( - facebook::jni::gJavaLangIllegalArgumentException, - "Unsupported EValue type: %d", - evalue.tag); + std::stringstream ss; + ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]"; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); + return {}; } static TensorPtr JEValueToTensorImpl( @@ -213,10 +216,11 @@ class JEValue : public facebook::jni::JavaClass { auto jtensor = jMethodGetTensor(JEValue); return TensorHybrid::newTensorFromJTensor(jtensor); } - facebook::jni::throwNewJavaException( - facebook::jni::gJavaLangIllegalArgumentException, - "Unknown EValue typeCode %d", - typeCode); + std::stringstream ss; + ss << "Unknown EValue typeCode: " << typeCode; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); + return {}; } }; @@ -296,13 +300,26 @@ class ExecuTorchJni : public facebook::jni::HybridClass { jinputs) { // If no inputs is given, it will run with sample inputs (ones) if (jinputs->size() == 0) { - if (module_->load_method(method) != Error::Ok) { + auto result = module_->load_method(method); + if (result != Error::Ok) { + // Format hex string + std::stringstream ss; + ss << "Cannot get method names [Native Error: 0x" << std::hex + << std::uppercase << static_cast(result) << "]"; + + jni_helper::throwExecutorchException( + static_cast( + Error::InvalidArgument), // For backward compatibility + ss.str()); return {}; } auto&& underlying_method = module_->methods_[method].method; auto&& buf = prepare_input_tensors(*underlying_method); - auto result = underlying_method->execute(); + result = underlying_method->execute(); if (result != Error::Ok) { + jni_helper::throwExecutorchException( + static_cast(result), + "Execution failed for method: " + method); return {}; } facebook::jni::local_ref> jresult = @@ -356,11 +373,9 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #endif if (!result.ok()) { - facebook::jni::throwNewJavaException( - "java/lang/Exception", - "Execution of method %s failed with status 0x%" PRIx32, - method.c_str(), - static_cast(result.error())); + jni_helper::throwExecutorchException( + static_cast(result.error()), + "Execution failed for method: " + method); return {}; } @@ -438,9 +453,17 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::local_ref> getMethods() { const auto& names_result = module_->method_names(); if (!names_result.ok()) { - facebook::jni::throwNewJavaException( - facebook::jni::gJavaLangIllegalArgumentException, - "Cannot get load module"); + // Format hex string + std::stringstream ss; + ss << "Cannot get load module [Native Error: 0x" << std::hex + << std::uppercase << static_cast(names_result.error()) + << "]"; + + jni_helper::throwExecutorchException( + static_cast( + Error::InvalidArgument), // For backward compatibility + ss.str()); + return {}; } const auto& methods = names_result.get(); facebook::jni::local_ref> ret = diff --git a/extension/android/jni/selective_jni.buck.bzl b/extension/android/jni/selective_jni.buck.bzl index d557606b7d1..8e20f903ca9 100644 --- a/extension/android/jni/selective_jni.buck.bzl +++ b/extension/android/jni/selective_jni.buck.bzl @@ -10,6 +10,7 @@ def selective_jni_target(name, deps, srcs = [], soname = "libexecutorch.$(ext)") srcs = [ "//xplat/executorch/extension/android/jni:jni_layer.cpp", "//xplat/executorch/extension/android/jni:jni_layer_runtime.cpp", + "//xplat/executorch/extension/android/jni:jni_helper.cpp", ] + srcs, allow_jni_merging = False, compiler_flags = ET_JNI_COMPILER_FLAGS,